{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.insert(0, '../')\n",
    "\n",
    "# No GPU because working locally\n",
    "import os\n",
    "os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"\"\n",
    "\n",
    "\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "from gantools import utils\n",
    "from gantools import plot\n",
    "from gantools.model import InpaintingGAN\n",
    "from gantools.gansystem import GANsystem\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "from copy import deepcopy\n",
    "from gantools import blocks\n",
    "\n",
    "from audioinpainting.load import load_audio_dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "downscale = 2\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data handling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# spix=1024*16 equals `signal_length` in the model parameters further down.\n",
    "dataset = load_audio_dataset(\n",
    "    scaling=downscale,\n",
    "    type='piano',\n",
    "    spix=1024*16,\n",
    "    augmentation=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Number of samples: {}'.format(dataset.N))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The dataset can return an iterator.\n",
    "it = dataset.iter(10)\n",
    "print(next(it).shape)\n",
    "del it"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get all the data\n",
    "X = dataset.get_all_data().flatten()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Value range and histogram (100 bins, log-scaled counts) of all samples\n",
    "print('min: {}'.format(X.min()))\n",
    "print('max: {}'.format(X.max()))\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "ax.hist(X, bins=100)\n",
    "ax.set_yscale('log')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# to free some memory\n",
    "del X"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let us plot 16 samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot.audio.plot_signals(dataset.get_samples(N=16),nx=4,ny=4);\n",
    "plt.suptitle(\"Real samples\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot.audio.play_sound(dataset.get_samples(16)[0,:], fs=16000//downscale)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Define parameters for the WGAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "time_str = 'piano_inpaint'\n",
    "global_path = '../saved_results'\n",
    "\n",
    "name = 'WGAN' + '_' + time_str"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bn = False  # batch-norm flag shared by every layer list below\n",
    "signal_length = 1024*16\n",
    "# Lengths of [left border, middle section, right border]; they sum to signal_length.\n",
    "signal_split = [1024*6, 1024*4, 1024*6]\n",
    "md = 64  # base number of filters\n",
    "\n",
    "params_discriminator = {\n",
    "    'stride': [4, 4, 4, 4, 4],\n",
    "    'nfilter': [md, 2*md, 4*md, 8*md, 16*md],\n",
    "    'shape': [[25], [25], [25], [25], [25]],\n",
    "    'batch_norm': [bn, bn, bn, bn, bn],\n",
    "    'full': [md*4],\n",
    "    'minibatch_reg': False,\n",
    "    'summary': True,\n",
    "    'data_size': 1,\n",
    "    'apply_phaseshuffle': True,\n",
    "    'spectral_norm': True,\n",
    "    'activation': blocks.lrelu,\n",
    "}\n",
    "\n",
    "params_generator = {\n",
    "    'stride': [4, 4, 4, 4, 4],\n",
    "    'latent_dim': 100,\n",
    "    'nfilter': [8*md, 4*md, 2*md, md, 1],\n",
    "    'shape': [[25], [25], [25], [25], [25]],\n",
    "    'batch_norm': [bn, bn, bn, bn],\n",
    "    'full': [64*md],\n",
    "    'summary': True,\n",
    "    'non_lin': tf.nn.tanh,\n",
    "    'activation': tf.nn.relu,\n",
    "    'data_size': 1,\n",
    "    'spectral_norm': True,\n",
    "    'in_conv_shape': [4],\n",
    "    # Sub-network settings for the two signal borders given to the generator\n",
    "    'borders': {\n",
    "        'nfilter': [md, 2*md, 4*md, 8*md, 2*md],\n",
    "        'batch_norm': [bn, bn, bn, bn, bn],\n",
    "        'shape': [[25], [25], [25], [25], [25]],\n",
    "        'stride': [4, 4, 4, 4, 4],\n",
    "        'data_size': 1,\n",
    "        'width_full': 128,\n",
    "        'activation': tf.nn.relu,\n",
    "    },\n",
    "}\n",
    "\n",
    "params_optimization = {\n",
    "    'batch_size': 64,\n",
    "    'epoch': 10000,\n",
    "    'n_critic': 5,\n",
    "    'generator': {\n",
    "        'optimizer': 'adam',\n",
    "        'kwargs': {'beta1': 0.5, 'beta2': 0.9},\n",
    "        'learning_rate': 1e-4,\n",
    "    },\n",
    "    'discriminator': {\n",
    "        'optimizer': 'adam',\n",
    "        'kwargs': {'beta1': 0.5, 'beta2': 0.9},\n",
    "        'learning_rate': 1e-4,\n",
    "    },\n",
    "}\n",
    "\n",
    "# all parameters\n",
    "params = {\n",
    "    # All the parameters for the model\n",
    "    'net': {\n",
    "        'generator': params_generator,\n",
    "        'discriminator': params_discriminator,\n",
    "        'prior_distribution': 'gaussian',\n",
    "        'shape': [signal_length, 1],  # Shape of one signal: [number of samples, 1]\n",
    "        'inpainting': {'split': signal_split},\n",
    "        'gamma_gp': 10,  # Gradient penalty\n",
    "        'fs': 16000//downscale,\n",
    "        'loss_type': 'wasserstein',\n",
    "    },\n",
    "    'optimization': params_optimization,\n",
    "    'summary_every': 100,  # Tensorboard summaries every ** iterations\n",
    "    'print_every': 50,  # Console summaries every ** iterations\n",
    "    'save_every': 1000,  # Save the model every ** iterations\n",
    "    'summary_dir': os.path.join(global_path, name + '_summary/'),\n",
    "    'save_dir': os.path.join(global_path, name + '_checkpoints/'),\n",
    "    'Nstats': 100,\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "resume, params = utils.test_resume(True, params)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Build the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wgan = GANsystem(InpaintingGAN, params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "wgan.train(dataset, resume=resume)"
   ]
  },
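  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate new samples\n",
    "To obtain meaningful statistics, be sure to generate enough samples:\n",
    "* 2000 : 32 x 32\n",
    "* 500 : 64 x 64\n",
    "* 200 : 128 x 128\n",
    "\n",
    "The next cell draws `N` real signals, keeps their left and right borders, and asks the model to generate the corresponding signals from those borders.\n"
   ]
  },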
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "N = 16 # Number of samples\n",
    "real_signals = dataset.get_samples(N=N)\n",
    "border1 = real_signals[:,:signal_split[0]]\n",
    "border2 = real_signals[:,-signal_split[2]:]\n",
    "borders = np.stack([border1, border2], axis=2)\n",
    "fake_signals = np.squeeze(wgan.generate(N=N, borders=borders))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define a spectrogram helper, then listen to and display a few real and fake samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import ltfatpy\n",
    "from ltfatpy import plotdgtreal\n",
    "\n",
    "\n",
    "def plot_sgram(signal, a=256, M=512, g='itersine', dynrange=80, **kwargs):\n",
    "    \"\"\"Plot the spectrogram of a real-valued signal.\n",
    "\n",
    "    The signal is analysed with ltfatpy's `dgtreal` and the resulting\n",
    "    coefficients are drawn with `plotdgtreal`.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    signal : array-like\n",
    "        Real-valued signal. It is converted to float64 here, so callers may\n",
    "        pass float32 arrays (e.g. network outputs) without casting first.\n",
    "    a : int\n",
    "        Time-shift parameter forwarded to `dgtreal` and `plotdgtreal`.\n",
    "    M : int\n",
    "        Number-of-channels parameter forwarded to `dgtreal` and `plotdgtreal`.\n",
    "    g : str or array-like\n",
    "        Window specification forwarded to `dgtreal`.\n",
    "    dynrange : float\n",
    "        Dynamic range forwarded to `plotdgtreal`.\n",
    "    **kwargs\n",
    "        Any further options forwarded to `plotdgtreal` (this notebook passes `fs`).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    The value returned by `plotdgtreal`.\n",
    "    \"\"\"\n",
    "    # Callers in this notebook had to cast to float64 themselves; doing it\n",
    "    # here keeps that working and also accepts other real dtypes.\n",
    "    signal = np.asarray(signal, dtype=np.float64)\n",
    "    # dgtreal returns a tuple; only its first element (the coefficients) is plotted.\n",
    "    c = ltfatpy.gabor.dgtreal.dgtreal(signal, g, a, M)[0]\n",
    "    return plotdgtreal(c, a, M, dynrange=dynrange, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# For the first four examples, play the real signal followed by the fake one\n",
    "for i in range(4):\n",
    "    for label, signals in (('Real', real_signals), ('Fake', fake_signals)):\n",
    "        print(label)\n",
    "        plot.audio.play_sound(signals[i,:], fs=16000//downscale)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot.audio.plot_signals(fake_signals,nx=4,ny=4);\n",
    "plt.suptitle(\"Fake samples\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plot.audio.plot_signals(real_signals,nx=4,ny=4);\n",
    "plt.suptitle(\"Real samples\");"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One figure per example: inpainted spectrogram on the left, original on the right\n",
    "panels = (('Inpainted', fake_signals), ('Original', real_signals))\n",
    "for i in range(4):\n",
    "    plt.figure(figsize=(15, 4))\n",
    "    for column, (title, signals) in enumerate(panels, start=1):\n",
    "        plt.subplot(1, 2, column)\n",
    "        plot_sgram(signals[i].astype(np.float64), fs=16000//downscale)\n",
    "        plt.title(title)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
